import time
from tqdm import tqdm

import numpy as np
from sklearn.utils import resample


class C2UpBandit(object):
    """Bandit environment that serves batches of contexts from a fixed dataset.

    Every round a batch of contexts is drawn, one arm (treatment index) is
    chosen per context, and :meth:`feedback` converts the chosen arms into
    binary outcomes sampled from the estimated response probabilities.  When
    logged treatments and outcomes are supplied, the simulated outcome is
    replaced by the logged one wherever the chosen arm equals the logged
    treatment.
    """

    def __init__(self, Xs, ys_estimate,
                 treat=None, ys_true=None,
                 rng_sampling=42, rng_feedback=7):
        """Store the dataset, seed the generators and shuffle once.

        Args:
            Xs: 2-D feature array with one row per sample.  A constant column
                of ones is appended to it.
            ys_estimate: array indexed as ``ys_estimate[:, trt]``; its entries
                are used as success probabilities by :meth:`feedback`.
            treat: optional logged treatment per sample.
            ys_true: optional logged outcome per sample; read only when
                ``treat`` is given.
            rng_sampling: seed of the generator used to shuffle the dataset.
            rng_feedback: seed of the generator used to draw outcomes.
        """
        n_samples = len(Xs)
        # Append a bias column so that learners see an intercept feature.
        self.Xs = np.hstack([Xs, np.ones((n_samples, 1))])
        self.ys_estimate = ys_estimate
        self.treat = treat
        self.ys_true = ys_true
        self.init_random_seed(rng_sampling, rng_feedback)
        self.counter = 0
        self.shuffle_dataset()
        # Per-round statistics of the reference policies.
        self.optimal_uplifts = []
        self.rewards_optimal_arm = []
        self.rewards_optimal_arm_from_data = []
        self.total_uplifts = []
        self.rewards_treat_all = []
        self.rewards_treat_all_from_data = []

    def init_random_seed(self, rng_sampling, rng_feedback):
        """(Re)create both random generators and rewind the batch cursor."""
        self.rng_sampling = np.random.RandomState(rng_sampling)
        self.rng_feedback = np.random.default_rng(rng_feedback)
        self.counter = 0

    def shuffle_dataset(self):
        """Draw a new permutation of the dataset using ``rng_sampling``."""
        random_state = self.rng_sampling
        if self.treat is None:
            self.Xs_shuffled, self.ys_estimate_shuffled = resample(
                self.Xs, self.ys_estimate,
                replace=False, random_state=random_state)
        else:
            (self.Xs_shuffled, self.ys_estimate_shuffled,
             self.treat_shuffled, self.ys_true_shuffled) = resample(
                self.Xs, self.ys_estimate, self.treat, self.ys_true,
                replace=False, random_state=random_state)

    def sample_contexts(self, n_samples, random_state=None):
        """Return the next ``n_samples`` contexts of the shuffled dataset.

        The dataset is reshuffled and the cursor rewound whenever the
        remaining rows cannot fill a whole batch.  ``random_state`` is
        accepted for signature compatibility with
        :meth:`sample_contexts_direct` and is not used.
        """
        if self.counter + n_samples > len(self.Xs):
            self.shuffle_dataset()
            self.counter = 0
        sampled = np.s_[self.counter:self.counter + n_samples]
        self.counter += n_samples
        self.Xs_sampled = self.Xs_shuffled[sampled]
        self.ys_estimate_sampled = self.ys_estimate_shuffled[sampled]
        if self.treat is not None:
            self.treat_sampled = self.treat_shuffled[sampled]
            self.ys_true_sampled = self.ys_true_shuffled[sampled]
        return self.Xs_sampled

    def sample_contexts_direct(self, n_samples, random_state=None):
        """Draw ``n_samples`` contexts without replacement from the full data.

        Unlike :meth:`sample_contexts` this ignores the shuffled copy and the
        batch cursor; ``random_state`` is forwarded to ``resample``.
        """
        if self.treat is None:
            self.Xs_sampled, self.ys_estimate_sampled = resample(
                self.Xs, self.ys_estimate,
                n_samples=n_samples, replace=False, random_state=random_state)
            return self.Xs_sampled
        (self.Xs_sampled, self.ys_estimate_sampled,
         self.treat_sampled, self.ys_true_sampled) = resample(
            self.Xs, self.ys_estimate, self.treat, self.ys_true,
            n_samples=n_samples, replace=False, random_state=random_state)
        return self.Xs_sampled

    def feedback(self, arm):
        """Simulate the outcome of playing ``arm`` on the current batch.

        Args:
            arm: one treatment index per sampled context.

        Returns:
            A tuple ``(realization, uplift, reward)``:

            * ``realization`` -- float array of 0/1 outcomes, one per sample.
            * ``uplift`` -- sum over the batch of the estimated probability
              under the chosen arm minus the one under arm 0.
            * ``reward`` -- without logged data, the sum of ``realization``.
              With logged data, the sum over the treatments used in ``arm``
              of (mean logged outcome among samples whose logged treatment
              matches) times (number of samples assigned to it).  It is
              ``nan`` when an assigned treatment has no matching logged
              sample, since that mean is undefined.
        """
        arm = np.asarray(arm)
        n_samples = len(arm)
        if n_samples == 0:
            return np.zeros(0), 0.0, 0.0
        n_treatments = int(np.max(arm))
        probs = np.zeros(n_samples)
        for trt in range(n_treatments + 1):
            treated = arm == trt
            probs[treated] = self.ys_estimate_sampled[treated, trt]
        realization = (self.rng_feedback.random(n_samples) < probs).astype('float')
        uplift = np.sum(probs - self.ys_estimate_sampled[:, 0])
        reward = np.sum(realization)
        if self.treat is not None:
            reward = 0.0
            for trt in range(n_treatments + 1):
                assigned = arm == trt
                n_assigned = np.sum(assigned)
                if n_assigned == 0:
                    # A treatment nobody received contributes nothing; taking
                    # the mean of an empty selection here would turn the whole
                    # reward into nan.
                    continue
                overlap = np.logical_and(assigned, self.treat_sampled == trt)
                feedback_groundtruth = self.ys_true_sampled[overlap]
                # Prefer the logged outcome over the simulated one.
                realization[overlap] = feedback_groundtruth
                if len(feedback_groundtruth) == 0:
                    # No logged sample received this treatment, so the
                    # data-based estimate is undefined.
                    reward += np.nan
                else:
                    # Expected reward computed solely from ground-truth data.
                    reward += np.mean(feedback_groundtruth) * n_assigned
        return realization, uplift, reward

    def update_treat_optimal(self, budget, use_uplift=True):
        """Record the outcome of treating the ``budget`` best-scored samples.

        Only a single treatment (arm 1 versus arm 0) is considered.

        Args:
            budget: number of samples of the current batch to treat.  A
                non-positive budget treats nobody.
            use_uplift: rank by estimated uplift (column 1 minus column 0)
                when true, otherwise by the estimated response under
                treatment (column 1).
        """
        n_samples = len(self.Xs_sampled)
        if use_uplift:
            score = self.ys_estimate_sampled[:, 1] - self.ys_estimate_sampled[:, 0]
        else:
            score = self.ys_estimate_sampled[:, 1]
        arm_opt = np.zeros(n_samples)
        if budget > 0:
            # Guarded because ``[-0:]`` would select every sample.
            arm_opt[np.argsort(score)[-budget:]] = 1
        realization, optimal_uplift, reward = self.feedback(arm_opt)
        self.rewards_optimal_arm.append(np.sum(realization))
        self.optimal_uplifts.append(optimal_uplift)
        self.rewards_optimal_arm_from_data.append(reward)

    def update_treat_all(self):
        """Record the outcome of giving arm 1 to every sample of the batch."""
        n_samples = len(self.Xs_sampled)
        arm = np.ones(n_samples)
        realization, uplift, reward = self.feedback(arm)
        self.rewards_treat_all.append(np.sum(realization))
        self.total_uplifts.append(uplift)
        self.rewards_treat_all_from_data.append(reward)

def interact_c2(bandit, learner, n_rounds, n_samples, budget, logt=False, verbose=1):
    """Run ``n_rounds`` rounds of interaction between ``learner`` and ``bandit``.

    Each round samples ``n_samples`` contexts from the bandit, asks the
    learner for an arm vector, obtains feedback from the bandit and passes it
    back to the learner.

    Args:
        bandit: environment providing ``sample_contexts`` and ``feedback``.
            If it also provides ``optimal_expected_reward(budget)``, its
            value is printed every round when ``verbose > 0``.
        learner: policy providing ``act`` and ``update`` and exposing
            ``rewards``, ``uplifts`` and ``arm_his``.
        n_rounds: number of rounds to play.
        n_samples: number of contexts sampled per round.
        budget: budget forwarded to ``learner.act``.
        logt: when true, the round index is passed to ``learner.act`` as the
            keyword ``step``.
        verbose: ``> 0`` prints per-round statistics, ``-1`` prints progress
            every 100 rounds, any negative value disables the progress bar.

    Returns:
        The tuple ``(learner.rewards, learner.uplifts, learner.arm_his)``.
    """
    start = time.time()
    steps = range(n_rounds) if verbose < 0 else tqdm(range(n_rounds))
    # Not every bandit implements this diagnostic; look it up once so verbose
    # runs do not crash on environments that lack it.
    optimal_expected_reward = getattr(bandit, 'optimal_expected_reward', None)
    for step in steps:
        contexts = bandit.sample_contexts(n_samples)
        if logt:
            arm = learner.act(contexts, budget, step=step)
        else:
            arm = learner.act(contexts, budget)
        feedback, uplift, reward = bandit.feedback(arm)
        learner.update(arm, contexts, feedback, uplift, reward)
        if verbose > 0:
            if callable(optimal_expected_reward):
                print(optimal_expected_reward(budget))
            print((step, np.sum(feedback), uplift))
            if step % 10 == 0:
                print(f'time: {time.time()-start}')
        if verbose == -1 and step % 100 == 0:
            print(step)
            print(f'time: {time.time()-start}')
    return learner.rewards, learner.uplifts, learner.arm_his


def compute_regret(learner, optimal_uplifts, budget=0, cost=0):
    """Return per-round and cumulative regret in terms of net uplift.

    The learner's uplift in each round is reduced by ``cost`` times the sum
    of that round's row in ``learner.arm_his``; the reference uplift is
    reduced by ``budget * cost``.  Only the first ``len(learner.uplifts)``
    reference values are compared.

    Returns:
        ``(simple_regret, cumulative_regret)`` as arrays.
    """
    horizon = len(learner.uplifts)
    arms_per_round = np.asarray(learner.arm_his).sum(axis=1)
    learner_net = np.asarray(learner.uplifts) - cost * arms_per_round
    reference_net = np.asarray(optimal_uplifts) - budget * cost
    simple_regret = reference_net[:horizon] - learner_net
    return simple_regret, np.cumsum(simple_regret)


def compute_regret_realized(learner, rewards_optimal_arm):
    """Return per-round and cumulative regret based on realized rewards.

    The first ``len(learner.uplifts)`` entries of ``rewards_optimal_arm``
    are compared against ``learner.rewards``.

    Returns:
        ``(simple_regret, cumulative_regret)`` as arrays.
    """
    horizon = len(learner.uplifts)
    reference = np.asarray(rewards_optimal_arm[:horizon])
    achieved = np.asarray(learner.rewards)
    simple_regret = reference - achieved
    return simple_regret, np.cumsum(simple_regret)


def compute_regret_treat_all(treat_all, optimal, n_samples_per_round, budget, cost):
    """Return the regret of the treat-everyone policy against the reference.

    Each treat-all value is reduced by ``n_samples_per_round * cost`` and
    each reference value by ``budget * cost`` before they are compared.

    Returns:
        ``(simple_regret, cumulative_regret)`` as arrays.
    """
    treat_all_net = np.asarray(treat_all) - n_samples_per_round * cost
    reference_net = np.asarray(optimal) - budget * cost
    simple_regret = reference_net - treat_all_net
    return simple_regret, np.cumsum(simple_regret)
